from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging
import time
from pathlib import Path

import torch
import torch.optim as optim

from lib.core.config import get_model_name
from torch.utils.data.distributed import  DistributedSampler
import torch.distributed as dist

import lib.dataset as dataset


def create_logger(cfg, cfg_name, phase='train'):
    """Configure root logging and create the output directories for a run.

    The log file and checkpoints go under
    ``<cfg.OUTPUT_DIR>/<datasets>/<model>/<cfg_name>``; the tensorboard
    directory goes under
    ``<cfg.LOG_DIR>/<datasets>/<model>/<cfg_name>_<timestamp>``.

    Args:
        cfg: Config object. Reads ``cfg.OUTPUT_DIR``, ``cfg.LOG_DIR`` and the
            ``DATASET`` attribute of every entry of
            ``cfg.DATASET.TRAIN_DATASET``; it is also passed to
            ``get_model_name``.
        cfg_name: Path or file name of the config file. Its base name, cut at
            the first ``.``, is used as the run name.
        phase: Tag appended to the log file name.

    Returns:
        A tuple ``(logger, final_output_dir, tensorboard_log_dir)`` holding
        the root logger and the two directory paths as strings.
    """
    root_output_dir = Path(cfg.OUTPUT_DIR)
    # set up logger
    if not root_output_dir.exists():
        print('=> creating {}'.format(root_output_dir))
        # parents/exist_ok: several processes may reach this point at the
        # same time, and OUTPUT_DIR may be nested below missing directories.
        root_output_dir.mkdir(parents=True, exist_ok=True)

    # Named ``dataset_name`` so the module-level ``dataset`` import is not
    # shadowed inside this function.
    dataset_name = '_'.join([d.DATASET for d in cfg.DATASET.TRAIN_DATASET])
    model, _ = get_model_name(cfg)
    cfg_name = os.path.basename(cfg_name).split('.')[0]

    final_output_dir = root_output_dir / dataset_name / model / cfg_name

    print('=> creating {}'.format(final_output_dir))
    final_output_dir.mkdir(parents=True, exist_ok=True)

    # Minute-resolution timestamp shared by the log file name and the
    # tensorboard directory name.
    time_str = time.strftime('%Y-%m-%d-%H-%M')
    log_file = '{}_{}_{}.log'.format(cfg_name, time_str, phase)
    final_log_file = final_output_dir / log_file
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(filename=str(final_log_file),
                        format=head)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # Also send records to a console stream handler on the root logger.
    console = logging.StreamHandler()
    logging.getLogger('').addHandler(console)

    tensorboard_log_dir = Path(cfg.LOG_DIR) / dataset_name / model / \
        (cfg_name + '_' + time_str)
    print('=> creating {}'.format(tensorboard_log_dir))
    tensorboard_log_dir.mkdir(parents=True, exist_ok=True)

    return logger, str(final_output_dir), str(tensorboard_log_dir)


def get_optimizer(cfg, model):
    """Create the optimizer selected by ``cfg.TRAIN.OPTIMIZER``.

    Args:
        cfg: Config object. Reads ``TRAIN.OPTIMIZER`` and ``TRAIN.LR``; for
            ``'sgd'`` also ``TRAIN.MOMENTUM``, ``TRAIN.WD`` and
            ``TRAIN.NESTEROV``.
        model: Object whose ``parameters()`` are handed to the optimizer.

    Returns:
        An ``optim.SGD`` for ``'sgd'``, an ``optim.Adam`` for ``'adam'``, and
        ``None`` for any other value.
    """
    train_cfg = cfg.TRAIN
    choice = train_cfg.OPTIMIZER

    if choice == 'sgd':
        return optim.SGD(
            model.parameters(),
            lr=train_cfg.LR,
            momentum=train_cfg.MOMENTUM,
            weight_decay=train_cfg.WD,
            nesterov=train_cfg.NESTEROV,
        )

    if choice == 'adam':
        # Adam only takes the learning rate from the config here.
        return optim.Adam(model.parameters(), lr=train_cfg.LR)

    # Unrecognised names fall through without raising.
    return None


def save_checkpoint(states, is_best, output_dir,
                    filename='checkpoint.pth.tar'):
    """Write a checkpoint, plus a best-model snapshot when requested.

    Args:
        states: Object to serialize with ``torch.save``. It is tested with
            ``in`` for a ``'state_dict'`` entry when ``is_best`` is truthy.
        is_best: When truthy and ``states`` has a ``'state_dict'`` entry,
            that entry is also saved as ``model_best.pth.tar``.
        output_dir: Directory that receives the files.
        filename: File name of the full checkpoint inside ``output_dir``.
    """
    checkpoint_path = os.path.join(output_dir, filename)
    torch.save(states, checkpoint_path)

    if not is_best or 'state_dict' not in states:
        return

    # Only the model weights are kept for the best snapshot.
    best_path = os.path.join(output_dir, 'model_best.pth.tar')
    torch.save(states['state_dict'], best_path)


def load_checkpoint(model, optimizer, output_dir, filename='checkpoint.pth.tar'):
    """Restore model and optimizer state from a checkpoint file, if present.

    The checkpoint is loaded onto the CPU and must provide the keys
    ``'epoch'``, ``'state_dict'`` and ``'optimizer'``; ``'iteration'`` is
    optional.

    Args:
        model: Model to restore. When it exposes a ``module`` attribute (as
            parallel wrappers do) the weights are loaded into
            ``model.module``; otherwise they are loaded into ``model``
            directly.
        optimizer: Optimizer whose state is restored in place.
        output_dir: Directory containing the checkpoint.
        filename: Checkpoint file name inside ``output_dir``.

    Returns:
        A tuple ``(start_epoch, model, optimizer, iteration)``. When no
        checkpoint file exists this is ``(0, model, optimizer, 1)`` with
        both objects untouched.
    """
    checkpoint_path = os.path.join(output_dir, filename)
    logger = logging.getLogger()

    if not os.path.isfile(checkpoint_path):
        logger.info('=> no checkpoint found at {}'.format(checkpoint_path))
        return 0, model, optimizer, 1

    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    start_epoch = checkpoint['epoch']

    # Load into the wrapped network when there is one; fall back to the
    # model itself so that unwrapped models can be resumed as well.
    target = getattr(model, 'module', model)
    target.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])

    # Checkpoints without an iteration counter restart counting at 1.
    iteration = checkpoint['iteration'] if 'iteration' in checkpoint else 1

    logger.info('=> load checkpoint {} (epoch {})'
                .format(checkpoint_path, start_epoch))

    return start_epoch, model, optimizer, iteration


def get_training_set(config, transform):
    """Instantiate the training dataset described by the config.

    With more than one entry in ``config.DATASET.TRAIN_DATASET`` a
    ``dataset.composite`` is built over all of them. Otherwise the first
    entry's ``DATASET`` name is looked up as an attribute of the ``dataset``
    package and instantiated on its own.

    Args:
        config: Config object; reads ``DATASET.TRAIN_DATASET``.
        transform: Transform passed through to the dataset constructor.

    Returns:
        The constructed dataset object.

    Raises:
        IndexError: If ``config.DATASET.TRAIN_DATASET`` is empty.
        AttributeError: If the configured name does not exist in the
            ``dataset`` package.
    """
    train_cfgs = config.DATASET.TRAIN_DATASET

    if len(train_cfgs) > 1:
        return dataset.composite(
            config,
            train_cfgs,
            True,
            transform
        )

    subset_cfg = train_cfgs[0]
    # Resolve the class by attribute lookup rather than ``eval`` so that a
    # config string can only name something inside the ``dataset`` package,
    # never run arbitrary code. Dotted names are walked one part at a time.
    factory = dataset
    for part in subset_cfg.DATASET.split('.'):
        factory = getattr(factory, part)

    return factory(
        config,
        subset_cfg,
        True,
        transform
    )


def get_training_loader(trainset, config):
    """Build the training ``DataLoader`` together with its sampler.

    Args:
        trainset: Training dataset; must expose a ``dataset_type``
            attribute.
        config: Config object. Reads ``DATASET.IF_SAMPLE``,
            ``TRAIN.BATCH_SIZE`` and ``WORKERS``.

    Returns:
        A tuple ``(train_loader, sampler)`` where ``sampler`` is the
        ``DistributedSampler`` driving the loader.

    Raises:
        NotImplementedError: If ``trainset.dataset_type`` is ``'mixed'`` and
            ``config.DATASET.IF_SAMPLE`` is truthy; weighted sampling is not
            available for distributed training.
    """
    if trainset.dataset_type == 'mixed' and config.DATASET.IF_SAMPLE:
        # Raised explicitly instead of ``assert 0`` so the guard still holds
        # when Python runs with -O, which strips assert statements.
        raise NotImplementedError(
            'weightd distributed sampler not implemented so far!')

    my_sampler = DistributedSampler(trainset)
    train_loader = torch.utils.data.DataLoader(
        trainset,
        # Used as-is, without multiplying by the number of processes.
        batch_size=config.TRAIN.BATCH_SIZE,
        # Ordering is delegated to the sampler.
        shuffle=False,
        sampler=my_sampler,
        # The configured worker count is divided across all processes.
        num_workers=int(config.WORKERS / dist.get_world_size()),
        pin_memory=False,
        drop_last=True)
    return train_loader, my_sampler

